{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1fb76974-93ea-4b9c-81b1-55f826e7a361",
   "metadata": {},
   "outputs": [],
   "source": [
    "########################################################################################################\n",
    "# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
    "########################################################################################################\n",
    "\n",
    "import numpy as np\n",
    "# compact numpy printing: 4 decimals, no scientific notation, 200-character lines\n",
    "np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
    "import types, torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "# Aliases for the TorchScript base class and method decorator: RWKV_RNN subclasses MyModule\n",
    "# and compiles channel_mixing / time_mixing with MyFunction.\n",
    "MyModule = torch.jit.ScriptModule\n",
    "MyFunction = torch.jit.script_method"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c97049c-d3ae-4c72-bff4-d99416f8d650",
   "metadata": {},
   "source": [
    "RWKV-5 又名 Eagle（RWKV-6 又名 Finch）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b1059eca-db4f-4c0b-ae3e-37af49ec7fa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1c8d8009-7ee7-4419-aacb-cdc45f287010",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RWKV_TOKENIZER():\n",
    "    \"\"\"Greedy longest-match byte-level tokenizer for the RWKV world vocabulary.\n",
    "\n",
    "    Each line of the vocab file has the form `<id> <python literal> <byte length>`,\n",
    "    where the literal is either a str (stored as its UTF-8 encoding) or a bytes object.\n",
    "    \"\"\"\n",
    "    table: list[list[list[bytes]]]  # table[b0][b1] -> multi-byte tokens starting with bytes b0, b1\n",
    "    good: list[set[int]]            # good[b0] -> second bytes b1 that start some multi-byte token\n",
    "    wlen: list[int]                 # wlen[b0] -> length of the longest token starting with byte b0\n",
    "\n",
    "    def __init__(self, file_name):\n",
    "        \"\"\"Load the vocabulary from `file_name` and precompute the prefix lookup tables.\"\"\"\n",
    "        import ast  # local import: safe literal parsing of vocab entries\n",
    "\n",
    "        self.idx2token = {}\n",
    "        sorted_tokens = []  # must be already sorted\n",
    "        with open(file_name, \"r\", encoding=\"utf-8\") as f:  # close the file deterministically\n",
    "            lines = f.readlines()\n",
    "        for l in lines:\n",
    "            idx = int(l[:l.index(' ')])\n",
    "            # literal_eval (not eval) so a tampered vocab file cannot execute code\n",
    "            x = ast.literal_eval(l[l.index(' '):l.rindex(' ')].strip())\n",
    "            x = x.encode(\"utf-8\") if isinstance(x, str) else x\n",
    "            assert isinstance(x, bytes)\n",
    "            assert len(x) == int(l[l.rindex(' '):])  # declared byte length must match\n",
    "            sorted_tokens.append(x)\n",
    "            self.idx2token[idx] = x\n",
    "\n",
    "        self.token2idx = {v: int(k) for k, v in self.idx2token.items()}\n",
    "\n",
    "        # precompute some tables for fast matching\n",
    "        self.table = [[[] for j in range(256)] for i in range(256)]\n",
    "        self.good = [set() for i in range(256)]\n",
    "        self.wlen = [0 for i in range(256)]\n",
    "\n",
    "        for s in reversed(sorted_tokens):  # reverse order - match longer tokens first\n",
    "            if len(s) >= 2:\n",
    "                s0 = int(s[0])\n",
    "                s1 = int(s[1])\n",
    "                self.table[s0][s1].append(s)\n",
    "                self.wlen[s0] = max(self.wlen[s0], len(s))\n",
    "                self.good[s0].add(s1)\n",
    "\n",
    "    def encodeBytes(self, src: bytes) -> list[int]:\n",
    "        \"\"\"Greedily encode raw bytes into token ids, taking the first matching candidate token.\"\"\"\n",
    "        src_len: int = len(src)\n",
    "        tokens: list[int] = []\n",
    "        i: int = 0\n",
    "        while i < src_len:\n",
    "            s: bytes = src[i : i + 1]  # fallback: the single byte at position i\n",
    "\n",
    "            if i < src_len - 1:\n",
    "                s1: int = int(src[i + 1])\n",
    "                s0: int = int(src[i])\n",
    "                if s1 in self.good[s0]:\n",
    "                    sss: bytes = src[i : i + self.wlen[s0]]\n",
    "                    # first candidate that prefixes the remaining input; keep the single byte if none does\n",
    "                    s = next(filter(sss.startswith, self.table[s0][s1]), s)\n",
    "            tokens.append(self.token2idx[s])\n",
    "            i += len(s)\n",
    "\n",
    "        return tokens\n",
    "\n",
    "    def decodeBytes(self, tokens):\n",
    "        \"\"\"Concatenate the byte strings of the given token ids.\"\"\"\n",
    "        return b''.join(map(lambda i: self.idx2token[i], tokens))\n",
    "\n",
    "    def encode(self, src: str):\n",
    "        \"\"\"Encode a str as UTF-8 and tokenize it.\"\"\"\n",
    "        return self.encodeBytes(src.encode(\"utf-8\"))\n",
    "\n",
    "    def decode(self, tokens):\n",
    "        \"\"\"Decode token ids back to a str (raises UnicodeDecodeError on incomplete UTF-8).\"\"\"\n",
    "        return self.decodeBytes(tokens).decode('utf-8')\n",
    "\n",
    "    def printTokens(self, tokens):\n",
    "        \"\"\"Print each token's repr (decoded when it is valid UTF-8) followed by its id.\"\"\"\n",
    "        for i in tokens:\n",
    "            s = self.idx2token[i]\n",
    "            try:\n",
    "                s = s.decode('utf-8')\n",
    "            except UnicodeDecodeError:\n",
    "                pass  # partial UTF-8 sequence: show the raw bytes instead\n",
    "            print(f'{repr(s)}{i}', end=' ')\n",
    "            # print(repr(s), i)\n",
    "        print()\n",
    "\n",
    "########################################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "63a4e8ba-a291-4fdc-aef1-ebfca21840d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_logits(out, temperature=1.0, top_p=0.8):\n",
    "    \"\"\"Sample a token index from a logits vector with nucleus (top-p) sampling.\n",
    "\n",
    "    Args:\n",
    "        out: 1-D torch tensor of logits; must be convertible with .numpy() (CPU, no grad).\n",
    "        temperature: after top-p filtering, probabilities are raised to 1/temperature.\n",
    "        top_p: keep the most probable tokens up to the first one whose cumulative\n",
    "            probability exceeds top_p; if none exceeds it (e.g. top_p >= 1), keep all.\n",
    "\n",
    "    Returns:\n",
    "        The sampled index (a numpy integer).\n",
    "    \"\"\"\n",
    "    probs = F.softmax(out, dim=-1).numpy()\n",
    "    sorted_probs = np.sort(probs)[::-1]\n",
    "    cumulative_probs = np.cumsum(sorted_probs)\n",
    "    # First position whose cumulative probability exceeds top_p. np.argmax on an all-False mask\n",
    "    # returns 0 and would keep only the single top token, so clamp to the last position instead.\n",
    "    cutoff_idx = min(int(np.searchsorted(cumulative_probs, top_p, side='right')), len(sorted_probs) - 1)\n",
    "    cutoff = float(sorted_probs[cutoff_idx])\n",
    "    probs[probs < cutoff] = 0\n",
    "    if temperature != 1.0:\n",
    "        # numpy arrays have no .pow(); np.power is the element-wise equivalent\n",
    "        probs = np.power(probs, 1.0 / temperature)\n",
    "    probs = probs / np.sum(probs)\n",
    "    out = np.random.choice(a=len(probs), p=probs)\n",
    "    return out\n",
    "\n",
    "########################################################################################################"
   ]
  },
  {
   "cell_type": "raw",
   "id": "cb8c7d5e-08cb-4780-b6d9-ab8bad1417d4",
   "metadata": {},
   "source": [
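    "可以从这个链接下载模型（注意：下面第二个链接指向的是 RWKV-5-World-0.1B-v1 示例文件，而本笔记本实际加载的是 RWKV-5-World-0.4B-v2-20231113-ctx4096，请在同一页面下载对应的文件，并相应设置 args.MODEL_NAME、n_layer 和 n_embd）：\n",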
    "https://www.modelscope.cn/models/AI-ModelScope/rwkv-5-world/files\n",
    "https://www.modelscope.cn/api/v1/models/AI-ModelScope/rwkv-5-world/repo?Revision=master&FilePath=RWKV-5-World-0.1B-v1-20230803-ctx4096.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "94d7d6db-e89e-4209-ae72-6625ba85ef5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os  # local import: the checkpoint path can be overridden via an environment variable\n",
    "\n",
    "tokenizer = RWKV_TOKENIZER(\"./rwkv_vocab_v20230424.txt\")\n",
    "\n",
    "# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS\n",
    "\n",
    "args = types.SimpleNamespace()\n",
    "# Checkpoint path WITHOUT the .pth suffix (RWKV_RNN appends '.pth' when loading).\n",
    "# Set RWKV_MODEL_NAME to point at a local copy instead of editing this machine-specific path.\n",
    "args.MODEL_NAME = os.environ.get('RWKV_MODEL_NAME', '/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096')\n",
    "args.n_layer = 24  # must match the checkpoint's architecture\n",
    "args.n_embd = 1024  # must match the checkpoint's architecture\n",
    "args.vocab_size = 65536"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c8dcf39a-7838-454b-85fc-ec9bd75fa243",
   "metadata": {},
   "outputs": [],
   "source": [
    "# String copies of the architecture sizes, derived from `args` (previous cell) so they cannot drift.\n",
    "# (A 12-layer / 768-dim checkpoint would instead need args.n_layer = 12, args.n_embd = 768.)\n",
    "N_LAYER = str(args.n_layer)\n",
    "N_EMBD = str(args.n_embd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "74d7c96a-6fbc-401c-8078-fefb1a6ec5c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative prompts:\n",
    "# context = \"\\nElon Musk has\"\n",
    "# context = \"\\n我们发现\"\n",
    "context = \"Q:Do you know datawhalechina?\\nA:\"\n",
    "\n",
    "# Generation settings\n",
    "NUM_TRIALS = 3\n",
    "LENGTH_PER_TRIAL = 4096\n",
    "TEMPERATURE = 1.0\n",
    "TOP_P = 0.7"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ac1c244-71e3-4263-8ad8-4d0cf681ebfd",
   "metadata": {},
   "source": [
    "Eagle (RWKV-5) 和 Finch (RWKV-6) 相较于基础的RWKV-4架构在建模上的改进：\n",
    "\n",
    "1. **改进步骤**：\n",
    "   - **Eagle的改进**：Eagle模型在RWKV-4的基础上进行了多项改进，包括引入矩阵值的注意力状态（matrix-valued attention states）、在注意力头上应用LayerNorm（层归一化）、使用SiLU（Sigmoid-Weighted Linear Unit）进行注意力门控、并改进了初始化方法。此外，Eagle移除了接受度（receptance）函数中的Sigmoid激活函数。\n",
    "   - **Finch的改进**：Finch模型进一步引入了对衰减计划（decay schedule）和令牌移位（token-shift）的数据依赖性（data-dependence），使模型在处理时间和令牌数据时更加灵活和精确。\n",
    "\n",
    "2. **核心架构**：\n",
    "   - 这些模型的核心架构依然类似于RWKV-4，由一系列堆叠的残差块组成，形状类似于传统的Transformer架构。\n",
    "   - 每个块包含一个预LayerNorm时间混合子层（Pre-LayerNorm Time-Mixing sub-layer）和一个预LayerNorm通道混合子层（Pre-LayerNorm Channel-Mixing sub-layer），对应于Transformer中的注意力子层和前馈网络子层。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd3d56d6-59af-41d0-9ac3-2cc5b4fb54ed",
   "metadata": {},
   "source": [
    "这个是RWKV 5的Channel Mixing的代码实现，可以对比一下RWKV 4的实现。\n",
    "\n",
    "\n",
    "```python\n",
    "    @MyFunction\n",
    "    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
    "        i0 = (2+self.head_size)*i+0\n",
    "        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)\n",
    "        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)\n",
    "        state[i0] = x\n",
    "        r = torch.sigmoid(rw @ xr)\n",
    "        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
    "        return r * (vw @ k)\n",
    "```\n",
    "\n",
    "RWKV 4的Channel Mixing的代码实现为：\n",
    "\n",
    "\n",
    "```python\n",
    "    @torch.jit.script_method\n",
    "    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
    "        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
    "        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
    "        state[5*i+0] = x\n",
    "        r = torch.sigmoid(rw @ xr)\n",
    "        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
    "        return r * (vw @ k)\n",
    "```\n",
    "\n",
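    "这里的`i`是层的索引（表示当前是第几层，取值为 0 到 n_layer-1），而不是层数。在RWKV 4的每一层中，Channel Mixing记录1个状态，Time Mixing记录4个状态，所以每层一共占用5行状态。而RWKV 5中每一层占用`2+self.head_size`行状态：第`(2+self.head_size)*i`行是Channel Mixing的token shift状态，下一行是Time Mixing的token shift状态，其余`self.head_size`行存放各个头的wkv状态矩阵。Channel Mixing记录的状态以及计算过程和RWKV 4是完全一样的。"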
   ]
  },
  {
   "cell_type": "markdown",
   "id": "976f399a-78ba-4fb2-9b52-d19afda8c5d0",
   "metadata": {},
   "source": [
    "![](./img/01.png)\n",
    "\n",
    "图1：RWKV架构概述。左侧：时间混合和通道混合块；右上角：作为RNN单元的RWKV时间混合块；中下部：前馈模块中的令牌移位模块和Eagle时间混合；右下角：Finch时间混合中的令牌移位模块。所有形状注释为简洁起见假设为单头。虚线箭头（左侧，右上角）表示在Finch中有连接，但在Eagle中没有。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea37d6cd-1348-450b-9f6e-198d7c1d8368",
   "metadata": {},
   "source": [
    "Eagle模型中采用的Token Shift技术：\n",
    "\n",
    "1. **Token Shift**：\n",
    "   - Eagle模型从之前的RWKV模型中采用了Token Shift技术，这类似于大小为2的一维因果卷积（1D causal convolution）。\n",
    "   - 在图1的中心底部可以看到该技术的示意图。\n",
    "\n",
    "2. **线性插值定义**：\n",
    "   - 为了更好地介绍Token Shift技术，定义了一些符号。\n",
    "   - 线性插值（lerp）在时间步$t$和$t-1$之间用于RWKV-4和Eagle Token Shift，定义如下：\n",
    "     \\begin{align*}\n",
    "     \\text{lerp}_{\\Box}(a, b) = a + (b - a) \\odot \\mu_{\\Box}\n",
    "     \\end{align*}\n",
    "   - 其中，每个$\\mu_{\\Box} \\in \\mathbb{R}^D$是一个可学习的向量。\n",
    "\n",
    "3. **Token Shift的功能**：\n",
    "   - Token Shift允许模型学习在每个时间步中分配新信息和旧信息的比例，适用于接受度（receptance）、键（key）、值（value）和门控向量（gate vectors）中的每个通道（$r, k, v, g$），且每个头部（head）独立且唯一地应用这些向量。\n",
    "   - 这使得即使在单层内，一个单独的头部也可以直接将过去和当前的令牌数据累积到这些向量的不同子空间中，从而形成感应头（induction heads）。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40ca7c1b-8faa-42d5-8c7d-c4f233063856",
   "metadata": {},
   "source": [
    "在Eagle和Finch模型中，通道混合模块（Channel Mixing module）的设置及其与RWKV-4架构的异同如下：\n",
    "\n",
    "1. **模块一致性**：\n",
    "   - 在Eagle和Finch模型中，通道混合模块与之前的RWKV-4架构基本相同。\n",
    "   - 唯一的区别在于Eagle模型中，通道混合模块的隐藏维度（hidden dimension）从原来的4D减少到了3.5D。\n",
    "\n",
    "2. **减少维度的原因**：\n",
    "   - 这个隐藏维度的减少是为了在Eagle时间混合（Eagle Time Mixing）中引入新的门控权重（gating weights）并确保与之前模型（在相同层数和嵌入维度下）的参数数量相等。\n",
    "\n",
    "3. **Finch模型中的处理**：\n",
    "   - 尽管Finch模型中增加了一些新的LoRA权重参数，但并没有进一步减少隐藏维度。\n",
    "\n",
    "4. **公式一致性**：\n",
    "   - 通道混合的公式与RWKV-4模型相同，为了符号一致性（notational consistency），再次列出这些公式：\n",
    "\n",
    "\\begin{align*}\n",
    "r'_t &= \\text{lerp}_{r'}(x'_t, x'_{t-1}) W_{r'} \\in \\mathbb{R}^D \\quad \\text{(公式10)} \\\\\n",
    "k'_t &= \\text{lerp}_{k'}(x'_t, x'_{t-1}) W_{k'} \\in \\mathbb{R}^{3.5D} \\quad \\text{(公式11)} \\\\\n",
    "v'_t &= \\text{ReLU}(k'_t)^2 W_{v'} \\in \\mathbb{R}^D \\quad \\text{(公式12)} \\\\\n",
    "o'_t &= \\sigma(r'_t) \\odot v'_t \\in \\mathbb{R}^D \\quad \\text{(公式13)}\n",
    "\\end{align*}\n",
    "\n",
    "这些公式描述了在时间步 $t$ 的通道混合操作：\n",
    "- 使用线性插值（lerp）计算 $r'_t$ 和 $k'_t$。\n",
    "- $v'_t$ 通过 $k'_t$ 的ReLU平方值乘以权重矩阵 $W_{v'}$ 得到。\n",
    "- $o'_t$ 是 $r'_t$ 经过 Sigmoid 激活函数 $\\sigma$ 后的输出与 $v'_t$ 的逐元素乘积。\n",
    "\n",
    "其中，3.5D 指的是一种表示维度的方式。在深度学习模型中，D 通常代表模型的隐藏层维度（即嵌入维度或特征空间的维度）。例如，如果模型的隐藏维度是256，那么4D表示这个维度被扩展为4倍，也就是1024。\n",
    "\n",
    "然而，3.5D 是一个不常见的表示方法，通常情况下，我们会看到整数倍的表示（如2D, 4D等）。在这里，3.5D代表的是隐藏维度的3.5倍。\n",
    "\n",
    "具体来说，如果模型的基础维度是D，那么3.5D就表示：\n",
    "\\begin{align*} 3.5D = 3.5 \\times D \\end{align*}\n",
    "\n",
    "假设D是256，那么3.5D就是：\n",
    "\\begin{align*} 3.5 \\times 256 = 896 \\end{align*}\n",
    "\n",
    "所以，3.5D就是指模型在特定层中使用的特征维度是基础维度的3.5倍。在这个文档中，作者提到从4D减少到3.5D，意味着他们减少了某个层或模块的特征维度，以便引入新的门控权重并保持参数数量的一致性。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67a5c0d4-57d5-4f9b-a2b7-bddd2250f08a",
   "metadata": {},
   "source": [
    "Eagle时间混合（Eagle Time Mixing）的公式及其操作方法如下：\n",
    "\n",
    "### 公式部分\n",
    "\n",
    "Eagle时间混合的公式如下：\n",
    "\n",
    "\\begin{align*}\n",
    "\\Box_t &= \\text{lerp}_{\\Box}(x_t, x_{t-1}) W_{\\Box}, \\quad \\Box \\in \\{r, k, v, g\\} \\tag{4} \\\\\n",
    "w &= \\exp(-\\exp(\\omega)) \\tag{5} \\\\\n",
    "\\mathrm{wkv}_t &= \\text{diag}(u) \\cdot k_t^\\top \\cdot v_t + \\sum_{i=1}^{t-1} \\text{diag}(w)^{t-1-i} \\cdot k_i^\\top \\cdot v_i \\in \\mathbb{R}^{(D/h) \\times (D/h)} \\tag{6} \\\\\n",
    "o_t &= \\text{concat} \\left( \\text{SiLU}(g_t) \\odot \\text{LayerNorm}(r_t \\cdot \\mathrm{wkv}_t) \\right) W_o \\in \\mathbb{R}^D \\tag{7}\n",
    "\\end{align*}\n",
    "\n",
    "### 解释部分\n",
    "\n",
    "- **LayerNorm的操作**：LayerNorm在每个头部（head）上独立操作，这相当于在h个组上执行GroupNorm（Wu & He，2018）。值得注意的是，$w$ 是由 $\\omega \\in \\mathbb{R}^{D/h}$ 通过公式 $w = \\exp(-\\exp(\\omega))$ 计算得到的，$\\omega$ 是实际的头部可训练参数。这确保了 $w$ 在区间 (0,1) 内，从而保证 $\\text{diag}(w)$ 是一个收缩矩阵。\n",
    "\n",
    "- **$\\mathrm{wkv}_t$ 计算**：$\\mathrm{wkv}_t$ 的注意力计算可以用递归形式写为（$s$ 为上一时间步的状态，$s'$ 为更新后的状态）：\n",
    "  \\begin{align*}\n",
    "  \\mathrm{wkv}' &= s + \\text{diag}(u) \\cdot k^\\top \\cdot v \\tag{8} \\\\\n",
    "  s' &= \\text{diag}(w) \\cdot s + k^\\top \\cdot v \\tag{9}\n",
    "  \\end{align*}\n",
    "\n",
    "- **解释RWKV的 $\\mathrm{wkv}_t$ 项**：RWKV的 $\\mathrm{wkv}_t$ 项可以被认为是归一化 $k^\\top v$ 项的基于衰减的等价物。值得注意的是，对于给定的头部 $j$，递归状态 $s$ 是 $k^\\top v$ 的和，其中 $s$ 的每个通道在每个时间步通过相应的 $w$ 通道单独衰减。在应用接受度向量、门控和输出权重之前，当前令牌的 $k^\\top v$ 被乘以一个每通道的可学习增益（bonus）$u$ 并与状态相加，见图1右上角。这使当前令牌得到了与衰减状态中累积的历史令牌不同的特殊处理。接受度（receptance）$r$ 与这个和相乘，作用类似于线性注意力中的查询（query）。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78a86c2f-962d-4826-baf4-bc19bc40b6e3",
   "metadata": {},
   "source": [
    "这里的最大的改进应该是现在的计算是分成了`H = self.n_head`个头，然后每个头的计算结果都被存到了state里。相比于RWKV-4，这种改进可以类比于Transformer的单头自注意力机制改到多头注意力机制。\n",
    "```python\n",
    "    @MyFunction\n",
    "    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
    "        H = self.n_head\n",
    "        S = self.head_size\n",
    "\n",
    "        i1 = (2+S)*i+1\n",
    "        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)\n",
    "        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)\n",
    "        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)\n",
    "        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)\n",
    "        state[i1] = x\n",
    "\n",
    "        r = (rw @ xr).view(H, 1, S)\n",
    "        k = (kw @ xk).view(H, S, 1)\n",
    "        v = (vw @ xv).view(H, 1, S)\n",
    "        g = F.silu(gw @ xg)\n",
    "\n",
    "        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
    "\n",
    "        x = torch.zeros(H, S)\n",
    "        a = k @ v\n",
    "        x = r @ (time_first * a + s)\n",
    "        s = a + time_decay * s\n",
    "    \n",
    "        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
    "        x = x.flatten()\n",
    "\n",
    "        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
    "        return ow @ x\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bd093a96-fdc5-460d-b39f-fe3735795b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RWKV_RNN(MyModule):\n",
    "    \"\"\"RNN-mode (one token per call) inference for an RWKV-5 (Eagle) checkpoint.\n",
    "\n",
    "    The recurrent state is one tensor of shape (n_layer * (2 + head_size), n_embd). For layer i,\n",
    "    row (2+head_size)*i is the channel-mixing token-shift input, row (2+head_size)*i+1 the\n",
    "    time-mixing token-shift input, and the following head_size rows hold the per-head wkv state.\n",
    "    \"\"\"\n",
    "    def __init__(self, args):\n",
    "        super().__init__()\n",
    "        self.args = args\n",
    "        self.eval() # set torch to inference mode\n",
    "\n",
    "        # NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code;\n",
    "        # only load trusted files (recent torch versions also accept weights_only=True).\n",
    "        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
    "        for k in w.keys():\n",
    "            w[k] = w[k].float() # convert to f32 type\n",
    "            if      '.time_' in k: w[k] = w[k].squeeze()\n",
    "            # stored decay is log(-log(w)); recover w = exp(-exp(.)), which lies in (0, 1)\n",
    "            if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)\n",
    "            # time_faaaa is passed to time_mixing as time_first (the current-token bonus)\n",
    "            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)\n",
    "\n",
    "        self.n_head = w['blocks.0.att.time_decay'].shape[0]\n",
    "        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head\n",
    "\n",
    "        self.w = types.SimpleNamespace() # set self.w from w\n",
    "        self.w.blocks = {}\n",
    "        for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
    "            parts = k.split('.')\n",
    "            last = parts.pop()\n",
    "            here = self.w\n",
    "            for p in parts:\n",
    "                if p.isdigit():\n",
    "                    # numeric key parts index into a dict (e.g. self.w.blocks)\n",
    "                    p = int(p)\n",
    "                    if p not in here: here[p] = types.SimpleNamespace()\n",
    "                    here = here[p]\n",
    "                else:\n",
    "                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
    "                    here = getattr(here, p)\n",
    "            setattr(here, last, w[k])\n",
    "\n",
    "    def layer_norm(self, x, w):\n",
    "        \"\"\"LayerNorm over the last n_embd features using w.weight / w.bias.\"\"\"\n",
    "        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
    "\n",
    "    @MyFunction\n",
    "    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
    "        # Channel mixing (FFN) of layer i; its token-shift row of `state` is updated in place.\n",
    "        i0 = (2+self.head_size)*i+0\n",
    "        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)\n",
    "        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)\n",
    "        state[i0] = x\n",
    "        r = torch.sigmoid(rw @ xr)\n",
    "        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
    "        return r * (vw @ k)\n",
    "\n",
    "    @MyFunction\n",
    "    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):\n",
    "        # Multi-head time mixing (attention) of layer i in RNN mode; `state` is updated in place.\n",
    "        H = self.n_head\n",
    "        S = self.head_size\n",
    "\n",
    "        # token shift: blend the current input with the previous one stored in row i1\n",
    "        i1 = (2+S)*i+1\n",
    "        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)\n",
    "        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)\n",
    "        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)\n",
    "        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)\n",
    "        state[i1] = x\n",
    "\n",
    "        r = (rw @ xr).view(H, 1, S)\n",
    "        k = (kw @ xk).view(H, S, 1)\n",
    "        v = (vw @ xv).view(H, 1, S)\n",
    "        g = F.silu(gw @ xg)\n",
    "\n",
    "        # the S rows after i1 hold this layer's per-head (S x S) wkv state\n",
    "        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)\n",
    "\n",
    "        a = k @ v  # (H, S, S): outer product k^T v of the current token\n",
    "        x = r @ (time_first * a + s)  # (H, 1, S): bonus-weighted current token plus decayed history\n",
    "        s = a + time_decay * s  # decay the history, then add the current token\n",
    "\n",
    "        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)\n",
    "        x = x.flatten()\n",
    "\n",
    "        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)\n",
    "        return ow @ x\n",
    "\n",
    "    def forward(self, token, state):\n",
    "        \"\"\"Feed one token through all layers.\n",
    "\n",
    "        Args:\n",
    "            token: index into the embedding matrix (presumably a single int token id).\n",
    "            state: state returned by a previous call, or None to start from zeros.\n",
    "\n",
    "        Returns:\n",
    "            (logits, state): float32 output of the head projection, and the state tensor\n",
    "            (mutated in place by the mixing layers).\n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            if state is None:\n",
    "                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)\n",
    "\n",
    "            x = self.w.emb.weight[token]\n",
    "            x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
    "            for i in range(self.args.n_layer):\n",
    "                att = self.w.blocks[i].att\n",
    "                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,\n",
    "                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,\n",
    "                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,\n",
    "                    att.ln_x.weight, att.ln_x.bias)\n",
    "                ffn = self.w.blocks[i].ffn\n",
    "                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,\n",
    "                    ffn.time_mix_k, ffn.time_mix_r,\n",
    "                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
    "\n",
    "            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
    "            return x.float(), state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "a330cd34-7ed0-4a6c-92a3-19797d34ee77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# context = \"Q:Do you know datawhalechina?\\nA:\"\n",
    "# Prompt for the generation trials (replaces the context set in the config cell).\n",
    "context = (\n",
    "    \"\\nQ:DataWhalechina is an organization founded at Shanghai Jiao Tong University \"\n",
    "    \"that helps learners learn artificial intelligence. How do you think of it?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "ad824161-413d-460c-9ffe-9dbfb739f86b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# checkpoint path without the .pth suffix (RWKV_RNN appends it when loading)\n",
    "args.MODEL_NAME"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "f0e2f841-4cda-48d4-b055-7adf00f2fe73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(24, 1024)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(args.n_layer, args.n_embd)  # architecture sizes the model will be built with"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "aba8a4d4-9a77-4191-a7ef-d5e6100ca3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# args.n_layer = 24\n",
    "# args.n_embd = 1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "dd44f7bc-e8d6-4242-beb5-89a866990751",
   "metadata": {},
   "outputs": [],
   "source": [
    "# args.n_layer = 12\n",
    "# args.n_embd = 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "2a96d9dc-8b5e-40cc-bb36-24c9bdeac29e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# args.MODEL_NAME='../models/rwkv-5-world-1b5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "b7d07606-31b4-4c21-9f89-554d89c2c866",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Using CPU. Loading /data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096 ...\n",
      "\n",
      "Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
     ]
    }
   ],
   "source": [
    "print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
    "# RWKV_RNN loads args.MODEL_NAME + '.pth' with map_location='cpu' and casts every weight to float32\n",
    "model = RWKV_RNN(args)\n",
    "\n",
    "print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
    "# None tells RWKV_RNN.forward to start from an all-zero recurrent state\n",
    "init_state = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "ce42cfad-0274-4d5d-950d-fb89ff11ed2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the recurrent state before the trials (None makes RWKV_RNN.forward start from zeros)\n",
    "init_state = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "3e02a81c-1447-4936-a241-4d00ecf8e862",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override LENGTH_PER_TRIAL (4096 in the config cell) for the trials below\n",
    "LENGTH_PER_TRIAL=1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "4a00ea05-d6fd-4052-b13a-8107fb268420",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "--[ Trial 0 ]----------------- \n",
      "Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?\n",
      "QI: I think that the group of students is actually the whole AI community.\n",
      "Q: In the first episode, how do you think you, a student, can use AI to solve a problem?\n",
      "QI: It's a great opportunity to help develop and build knowledge, so that if we see AI problems, we can help solve them.\n",
      "Q: How do you think that students can also participate in the teaching of AI?\n",
      "QI: It is very important to let the students to think that there is an AI problem, and we can solve it by teaching AI.\n",
      "Q: How do you think the research that we did on AI can be used to develop AI technologies?\n",
      "QI: The research is interesting and it can be used to develop AI technologies.\n",
      "Q: Do you think that students can learn from your research?\n",
      "QI: I think so.\n",
      "Q: You also talk about the use of AI in real-life applications. What do you think of that?\n",
      "QI: I think it's a good thing to see.\n",
      "Q: What are the major challenges that you see as being faced by the AI community?\n",
      "QI: One is how to find data that can help us solve problems. The other is how to find a good dataset.\n",
      "Q: You also talk about how we should deal with the big data problem. How do you think about that?\n",
      "QI: We should not think that it is impossible to handle big data. There are a lot of big data, but there is a problem of how to handle them.\n",
      "Q: What is the role of AI in industry?\n",
      "QI: AI plays an important role in industry. AI has helped us improve the quality of services. We have a lot of new applications that we are using AI to solve.\n",
      "Q: How do you think about AI and humans in the future?\n",
      "QI: AI is not just for humans. It is also used for us to learn, for example.\n",
      "Q: What do you think about the use of AI in the field of tourism?\n",
      "QI: It's not that easy to use AI in tourism. There are so many problems.\n",
      "Q: Do you think that AI will be a part of tourism in the future?\n",
      "QI: I think so. It is very important for us to see.\n",
      "Q: What do you think about AI and data sharing?\n",
      "QI: It is not that easy to use AI in data sharing.\n",
      "Q: What are the ways that you see AI in tourism?\n",
      "QI: AI can be used to solve problems.\n",
      "Q: How do you think about the relationship between AI and data?\n",
      "QI: We need to use AI in the future to help us solve problems.\n",
      "Q: How do you think about the relationship between AI and data sharing?\n",
      "Q: What do you think about the future of AI?\n",
      "Q: What are the main issues that AI is facing?\n",
      "Q: What are the biggest challenges that you see in the field of AI?\n",
      "Q: What do you think about the future of AI?\n",
      "Q: What are the biggest challenges that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main challenges that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "Q: What are the main trends that you see in the field of AI?\n",
      "\n",
      "--[ Trial 1 ]----------------- \n",
      "Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?\n",
      "M: We are always looking for data to make sure that we are doing the right thing. We are currently looking at how to do this through the webinars and learning events. We have had speakers from different areas, such as from Silicon Valley, who have participated in the series. The current speaker, Marco Aurelio, was from Hong Kong. He was doing a presentation on Artificial Intelligence.\n",
      "Q:How do you think of the audience that you are aiming to reach?\n",
      "M: We are aiming at the general audience. We are also targeting people in the financial industry, who are also interested in artificial intelligence.\n",
      "Q:What are your biggest challenges?\n",
      "M: One of the biggest challenges is that the audience is very educated. They know about artificial intelligence and data. But the difficulty is that we have to explain the whole technology to them.\n",
      "Q:How do you see the future of artificial intelligence?\n",
      "M: It is an interesting future. It is really interesting. We are starting to see many different developments. The technology is really getting better and better. There are different ways of data that are being created. We have the development of machines to pick words and sentences and machines to make the machines think.\n",
      "Q:How do you see the future of Artificial Intelligence?\n",
      "M: We are constantly working on how to make the future of artificial intelligence more human-like.\n",
      "Tags: dataWhalechina\n",
      "Previous PostFuture is one of the hottest topics in Artificial Intelligence\n",
      "Next PostOpinions about the future of Artificial Intelligence are changing\n",
      "Cotton Developer News: Hands-On With Artificial Intelligence\n",
      "Headlines from the data Whalechina Network: October 6, 2019\n",
      "DataWhalechina Network: July 30, 2019\n",
      "Cotton Developer News: July 22, 2019\n",
      "Headlines from the data Whalechina Network: June 22, 2019\n",
      "DataWhalechina Network: May 18, 2019\n",
      "Archives Select Month July 2019 June 2019 May 2019 April 2019 March 2019 February 2019 January 2019 December 2018 November 2018 October 2018 September 2018 August 2018 July 2018 June 2018 May 2018 April 2018 March 2018 February 2018 January 2018 December 2017 November 2017 October 2017 September 2017 August 2017 July 2017 June 2017 May 2017 April 2017 March 2017 February 2017 January 2017 December 2016 November 2016 October 2016 September 2016 August 2016 July 2016 June 2016 May 2016 April 2016 March 2016 February 2016 January 2016 December 2015 November 2015 October 2015 September 2015 August 2015 July 2015 June 2015 May 2015 April 2015 March 2015 February 2015 January 2015 December 2014 November 2014 October 2014 September 2014 August 2014 July 2014 June 2014 May 2014 April 2014 March 2014 February 2014 January 2014 December 2013 November 2013 October 2013 September 2013 August 2013 July 2013 June 2013 May 2013 April 2013 March 2013 February 2013 January 2013 December 2012 November 2012 October 2012 September 2012 August 2012 July 2012 June 2012 May 2012 April 2012 March 2012 February 2012 January 2012 December 2011 November 2011 October 2011 September 2011 August 2011 July 2011 June 2011 May 2011 April 2011 March 2011 February 2011 January 2011 December 2010 November 2010 October 2010 September 2010 August 2010 July 2010 June 2010 May 2010 April 2010 March 2010 February 2010 January 2010 December 2009 November 2009 October 2009 September 2009 August 2009 July 2009 June 2009 May 2009 April 2009 March 2009 February 2009 January 2009 December 2008 November 2008 October 2008 September 2008 August 2008 July 2008 June 2008 May 2008 April 2008 March 2008 February 2008 January 2008 December 2007 November 2007 October 2007 September 2007 August 2007 July 2007 June 2007 May 2007 April 2007 March 2007 February 2007 January 2007 December 2006 November 2006 October 2006 September 2006 August 2006 July 2006 June 2006 May 2006 April 2006 March 2006 February 2006 January 
2006 December 2005 November 2005 October 2005 September 2005 August 2005 July 2005 June 2005 May 2005 April 2005 March 2005 February 2005 January 2005 December 2004 November 2004 October 2004 September 2004 August 2004 July 2004 June 2004 May 2004 April 2004 March 2004 February 2004 January 2004 December 2003 November 2003 October 2003 September 2003 August 2003 July 2003 June 2003 May 2003 April 2003 March 2003 February 2003 January 2003 December 2002 November 2002 October 2002 September 2002 August 2002 July 2002\n",
      "\n",
      "--[ Trial 2 ]----------------- \n",
      "Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?\n",
      "Q:As AI continues to grow, what are some of the most promising applications of artificial intelligence?\n",
      "Q:How do you think artificial intelligence will affect the future of AI?\n",
      "Q:How does AI's role in education differ from the way it was used in the past?\n",
      "Q:What are some of the challenges AI will face in the future?\n",
      "Q:What is your vision for AI?\n",
      "Q:What are your key takeaways from this conference?\n",
      "Q:What do you hope to accomplish with AI?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are your current trends and plans for AI?\n",
      "Q:How can AI be applied in education?\n",
      "Q:What do you think will be the biggest impact of AI on education?\n",
      "Q:What is your vision for AI in the future?\n",
      "Q:How does AI change the way we teach and learn?\n",
      "Q:What are your hopes for the future of AI?\n",
      "Q:How does AI's role in education differ from the way it was used in the past?\n",
      "Q:What are some of the challenges AI will face in the future?\n",
      "Q:What is your vision for the future of AI?\n",
      "Q:What are your key takeaways from this conference?\n",
      "Q:What is your vision for the future of AI?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are your hopes for the future of AI?\n",
      "Q:What are your key takeaways from this conference?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q:What is your vision for AI's role in education?\n",
      "Q:What are your goals for the future of AI?\n",
      "Q:What are some of the biggest challenges AI will face in the future?\n",
      "Q\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Prime the model with the prompt once; every trial below restarts from this cached output/state.\n",
    "# NOTE(review): assumes init_state was initialised in an earlier cell (e.g. to None) - confirm.\n",
    "for token in tokenizer.encode(context):\n",
    "    init_out, init_state = model.forward(token, init_state)\n",
    "\n",
    "for TRIAL in range(NUM_TRIALS):\n",
    "    print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
    "    all_tokens = []\n",
    "    out_last = 0  # index into all_tokens of the first token not yet printed\n",
    "    # clone so each trial starts from its own copy of the primed prompt output/state\n",
    "    out, state = init_out.clone(), init_state.clone()\n",
    "    for i in range(LENGTH_PER_TRIAL):\n",
    "        token = sample_logits(out, TEMPERATURE, TOP_P)\n",
    "        all_tokens += [token]\n",
    "        try:\n",
    "            tmp = tokenizer.decode(all_tokens[out_last:])\n",
    "            if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
    "                print(tmp, end=\"\", flush=True)\n",
    "                out_last = i + 1\n",
    "        except Exception:\n",
    "            # Best-effort: the pending tokens may not decode yet (presumably a partial\n",
    "            # multi-byte sequence); keep buffering and retry after the next token.\n",
    "            # Catching Exception rather than using a bare except lets Ctrl-C\n",
    "            # (KeyboardInterrupt) actually stop generation.\n",
    "            pass\n",
    "        out, state = model.forward(token, state)\n",
    "print('\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d3eaf3-252a-43da-9414-e1c6f6c681fc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kewei-ai",
   "language": "python",
   "name": "kewei-ai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
